# Copyright 2023-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This config defines the architectural configurations of the Hugging Face version of a model.
"""


import transformers

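# Gemma 3 interleaves local and global attention: with sliding_window_pattern=6,
# every sixth layer is global (long-context RoPE base rope_theta, linearly
# scaled via rope_scaling), while the rest use 1024-token sliding windows at
# the local base rope_local_base_freq. The 896x896 SigLIP vision tower yields
# 64x64 = 4096 patches, pooled down to mm_tokens_per_image = 256 soft tokens.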
gemma3_4b_config = transformers.Gemma3Config(
    architectures=["Gemma3ForConditionalGeneration"],
    boi_token_index=255999,
    eoi_token_index=256000,
    eos_token_id=[1, 106],
    image_token_index=262144,
    initializer_range=0.02,
    mm_tokens_per_image=256,
    model_type="gemma3",
    text_config={
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": None,
        "cache_implementation": "hybrid",
        "final_logit_softcapping": None,
        "head_dim": 256,
        "hidden_activation": "gelu",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 10240,
        "max_position_embeddings": 163840,
        "model_type": "gemma3_text",
        "num_attention_heads": 8,
        "num_hidden_layers": 34,
        "num_key_value_heads": 4,
        "query_pre_attn_scalar": 256,
        "rms_norm_eps": 1e-06,
        "rope_local_base_freq": 10000.0,
        "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
        "rope_theta": 10000.0,
        "sliding_window": 1024,
        "sliding_window_pattern": 6,
        "use_cache": True,
        "vocab_size": 262144,
    },
    torch_dtype="bfloat16",
    vision_config={
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 896,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
        "vision_use_head": False,
    },
)

gemma3_12b_config = transformers.Gemma3Config(
    architectures=["Gemma3ForConditionalGeneration"],
    boi_token_index=255999,
    eoi_token_index=256000,
    eos_token_id=[1, 106],
    image_token_index=262144,
    initializer_range=0.02,
    mm_tokens_per_image=256,
    model_type="gemma3",
    text_config={
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": None,
        "cache_implementation": "hybrid",
        "final_logit_softcapping": None,
        "head_dim": 256,
        "hidden_activation": "gelu",
        "hidden_size": 3840,
        "initializer_range": 0.02,
        "intermediate_size": 15360,
        "max_position_embeddings": 163840,
        "model_type": "gemma3_text",
        "num_attention_heads": 16,
        "num_hidden_layers": 48,
        "num_key_value_heads": 8,
        "query_pre_attn_scalar": 256,
        "rms_norm_eps": 1e-06,
        "rope_local_base_freq": 10000.0,
        "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
        "rope_theta": 10000.0,
        "sliding_window": 1024,
        "sliding_window_pattern": 6,
        "use_cache": True,
        "vocab_size": 262144,
    },
    torch_dtype="bfloat16",
    vision_config={
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 896,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
        "vision_use_head": False,
    },
)

gemma3_27b_config = transformers.Gemma3Config(
    architectures=["Gemma3ForConditionalGeneration"],
    boi_token_index=255999,
    eoi_token_index=256000,
    eos_token_id=[1, 106],
    image_token_index=262144,
    initializer_range=0.02,
    mm_tokens_per_image=256,
    model_type="gemma3",
    text_config={
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": None,
        "cache_implementation": "hybrid",
        "final_logit_softcapping": None,
        "head_dim": 128,
        "hidden_activation": "gelu",
        "hidden_size": 5376,
        "initializer_range": 0.02,
        "intermediate_size": 21504,
        "max_position_embeddings": 163840,
        "model_type": "gemma3_text",
        "num_attention_heads": 32,
        "num_hidden_layers": 62,
        "num_key_value_heads": 16,
        "query_pre_attn_scalar": 168,
        "rms_norm_eps": 1e-06,
        "rope_local_base_freq": 10000.0,
        "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
        "rope_theta": 10000.0,
        "sliding_window": 1024,
        "sliding_window_pattern": 6,
        "use_cache": True,
        "vocab_size": 262144,
    },
    torch_dtype="bfloat16",
    vision_config={
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 896,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
        "vision_use_head": False,
    },
)


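# Gemma 2 configs list only the per-size fields; anything omitted (notably for
# the 2B variant) falls back to transformers.Gemma2Config defaults. The 9B and
# 27B variants override logit softcapping, head_dim, sliding_window, and
# query_pre_attn_scalar explicitly.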
gemma2_2b_config = transformers.Gemma2Config(
    num_hidden_layers=26,
    num_attention_heads=8,
    num_key_value_heads=4,
    hidden_size=2304,
    intermediate_size=9216,
)

gemma2_9b_config = transformers.Gemma2Config(
    num_hidden_layers=42,
    num_attention_heads=16,
    num_key_value_heads=8,
    hidden_size=3584,
    intermediate_size=14336,
    final_logit_softcapping=30.0,
    attn_logit_softcapping=50.0,
    head_dim=256,
    sliding_window=4096,
    query_pre_attn_scalar=224,
)

gemma2_27b_config = transformers.Gemma2Config(
    num_hidden_layers=46,
    num_attention_heads=32,
    num_key_value_heads=16,
    hidden_size=4608,
    intermediate_size=36864,
    final_logit_softcapping=30.0,
    attn_logit_softcapping=50.0,
    head_dim=128,
    sliding_window=4096,
    query_pre_attn_scalar=144,
)

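# Dense Qwen3 variants share the 151936-token vocabulary, 128-dim heads, and
# 8-way grouped-query attention; only the 0.6B and 4B models tie input and
# output embeddings.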
qwen3_0_6b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=1024,
    intermediate_size=3072,
    num_hidden_layers=28,
    num_attention_heads=16,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
    torch_dtype="bfloat16",
)

qwen3_4b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=2560,
    intermediate_size=9728,
    num_hidden_layers=36,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
    torch_dtype="bfloat16",
)

qwen3_8b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=4096,
    intermediate_size=12288,
    num_hidden_layers=36,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
)

qwen3_14b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=5120,
    intermediate_size=17408,
    num_hidden_layers=40,
    num_attention_heads=40,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
)

qwen3_32b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=5120,
    intermediate_size=25600,
    num_hidden_layers=64,
    num_attention_heads=64,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
)


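# Llama 3.1 reaches max_position_embeddings = 131072 from an 8192-token
# pretraining context via "llama3" rope scaling, which rescales RoPE
# frequencies non-uniformly (factor 8.0 on the low-frequency bands).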
llama31_8b_config = transformers.LlamaConfig(
    vocab_size=128256,
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,
    max_position_embeddings=131072,
    head_dim=128,
    rms_norm_eps=1e-5,
    bos_token_id=128000,
    eos_token_id=128001,
    attention_bias=False,
    attention_dropout=0.0,
    hidden_act="silu",
    initializer_range=0.02,
    mlp_bias=False,
    model_type="llama",
    pretraining_tp=1,
    rope_scaling={
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    rope_theta=500000.0,
    tie_word_embeddings=False,
    use_cache=True,
)

llama31_70b_config = transformers.LlamaConfig(
    vocab_size=128256,
    hidden_size=8192,
    intermediate_size=28672,
    num_hidden_layers=80,
    num_attention_heads=64,
    head_dim=128,
    num_key_value_heads=8,
    max_position_embeddings=131072,
    rms_norm_eps=1e-05,
    bos_token_id=128000,
    eos_token_id=[128001, 128008, 128009],
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    rope_theta=500000.0,
    tie_word_embeddings=False,
)

llama31_405b_config = transformers.LlamaConfig(
    vocab_size=128256,
    hidden_size=16384,
    intermediate_size=53248,
    num_hidden_layers=126,
    num_attention_heads=128,
    num_key_value_heads=8,
    head_dim=128,
    max_position_embeddings=131072,
    rms_norm_eps=1e-05,
    bos_token_id=128000,
    eos_token_id=128001,
)

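# Qwen3 MoE configs: decoder_sparse_step = 1 makes every decoder layer sparse,
# routing each token to num_experts_per_tok = 8 of num_experts experts, each
# with moe_intermediate_size hidden width. The "a3b"/"a22b"/"a35b" suffixes
# refer to activated parameters per token.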
qwen3_30b_a3b_thinking_2507_config = transformers.Qwen3MoeConfig(
    architectures=["Qwen3MoeForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=151643,
    decoder_sparse_step=1,
    eos_token_id=151645,
    head_dim=128,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=6144,
    max_position_embeddings=262144,
    max_window_layers=48,
    model_type="qwen3_moe",
    moe_intermediate_size=768,
    norm_topk_prob=True,
    num_attention_heads=32,
    num_experts=128,
    num_experts_per_tok=8,
    num_hidden_layers=48,
    num_key_value_heads=4,
    output_router_logits=False,
    rms_norm_eps=1e-06,
    rope_scaling=None,
    rope_theta=10000000,
    router_aux_loss_coef=0.001,
    sliding_window=None,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    use_cache=True,
    use_sliding_window=False,
    vocab_size=151936,
)

qwen3_235b_a22b_thinking_2507_config = transformers.Qwen3MoeConfig(
    architectures=["Qwen3MoeForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=151643,
    decoder_sparse_step=1,
    eos_token_id=151645,
    head_dim=128,
    hidden_act="silu",
    hidden_size=4096,
    initializer_range=0.02,
    intermediate_size=12288,
    max_position_embeddings=262144,
    max_window_layers=94,
    mlp_only_layers=[],
    model_type="qwen3_moe",
    moe_intermediate_size=1536,
    norm_topk_prob=True,
    num_attention_heads=64,
    num_experts=128,
    num_experts_per_tok=8,
    num_hidden_layers=94,
    num_key_value_heads=4,
    output_router_logits=False,
    rms_norm_eps=1e-06,
    rope_scaling=None,
    rope_theta=5000000.0,
    router_aux_loss_coef=0.001,
    sliding_window=None,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    transformers_version="4.51.0",
    use_cache=True,
    use_sliding_window=False,
    vocab_size=151936,
)

qwen3_coder_480b_a35b_config = transformers.Qwen3MoeConfig(
    architectures=["Qwen3MoeForCausalLM"],
    attention_dropout=0.0,
    decoder_sparse_step=1,
    eos_token_id=151645,
    head_dim=128,
    hidden_act="silu",
    hidden_size=6144,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=262144,
    max_window_layers=62,
    mlp_only_layers=[],
    model_type="qwen3_moe",
    moe_intermediate_size=2560,
    norm_topk_prob=True,
    num_attention_heads=96,
    num_experts=160,
    num_experts_per_tok=8,
    num_hidden_layers=62,
    num_key_value_heads=8,
    output_router_logits=False,
    qkv_bias=False,
    rms_norm_eps=1e-06,
    rope_scaling=None,
    rope_theta=10000000,
    router_aux_loss_coef=0.0,
    shared_expert_intermediate_size=0,
    sliding_window=None,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    transformers_version="4.51.0",
    use_cache=True,
    use_qk_norm=True,
    use_sliding_window=False,
    vocab_size=151936,
)

# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json,
# with the fp8 quantization_config removed since we load weights in bf16.
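# DeepSeek-V3 attention is multi-head latent attention (MLA): queries and
# keys/values are projected through low-rank bottlenecks (q_lora_rank=1536,
# kv_lora_rank=512), and each head splits into a 128-dim non-rotary part
# (qk_nope_head_dim) plus a 64-dim rotary part (qk_rope_head_dim).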
deepseek3_671b_dict = {
    "architectures": ["DeepseekV3ForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "auto_map": {
        "AutoConfig": "configuration_deepseek.DeepseekV3Config",
        "AutoModel": "modeling_deepseek.DeepseekV3Model",
        "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM",
    },
    "bos_token_id": 0,
    "eos_token_id": 1,
    "ep_size": 1,
    "first_k_dense_replace": 3,
    "hidden_act": "silu",
    "hidden_size": 7168,
    "initializer_range": 0.02,
    "intermediate_size": 18432,
    "kv_lora_rank": 512,
    "max_position_embeddings": 163840,
    "model_type": "deepseek_v3",
    "moe_intermediate_size": 2048,
    "moe_layer_freq": 1,
    "n_group": 8,
    "n_routed_experts": 256,
    "n_shared_experts": 1,
    "norm_topk_prob": True,
    "num_attention_heads": 128,
    "num_experts_per_tok": 8,
    "num_hidden_layers": 61,
    "num_key_value_heads": 128,
    "num_nextn_predict_layers": 1,
    "q_lora_rank": 1536,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "rms_norm_eps": 1e-06,
    "rope_scaling": {
        "beta_fast": 32,
        "beta_slow": 1,
        "factor": 40,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
        "original_max_position_embeddings": 4096,
        "type": "yarn",
    },
    "rope_theta": 10000,
    "routed_scaling_factor": 2.5,
    "scoring_func": "sigmoid",
    "tie_word_embeddings": False,
    "topk_group": 4,
    "topk_method": "noaux_tc",
    "torch_dtype": "bfloat16",
    "transformers_version": "4.33.1",
    "use_cache": True,
    "v_head_dim": 128,
    "vocab_size": 129280,
}
deepseek3_671b_config = transformers.DeepseekV3Config(**deepseek3_671b_dict)

# Copied from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json,
# with the mxfp4 quantization_config removed since we load weights in bf16.
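# YaRN rope scaling stretches the 4096-token pretraining context
# (initial_context_length) by factor 32.0 to max_position_embeddings = 131072.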
gpt_oss_20b_dict = {
    "architectures": ["GptOssForCausalLM"],
    "attention_bias": True,
    "attention_dropout": 0.0,
    "eos_token_id": 200002,
    "experts_per_token": 4,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2880,
    "initial_context_length": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 2880,
    "layer_types": [
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
    ],
    "max_position_embeddings": 131072,
    "model_type": "gpt_oss",
    "num_attention_heads": 64,
    "num_experts_per_tok": 4,
    "num_hidden_layers": 24,
    "num_key_value_heads": 8,
    "num_local_experts": 32,
    "output_router_logits": False,
    "pad_token_id": 199999,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "factor": 32.0,
        "original_max_position_embeddings": 4096,
        "rope_type": "yarn",
        "truncate": False,
    },
    "rope_theta": 150000,
    "router_aux_loss_coef": 0.9,
    "sliding_window": 128,
    "swiglu_limit": 7.0,
    "tie_word_embeddings": False,
    "transformers_version": "4.55.0.dev0",
    "use_cache": True,
    "vocab_size": 201088,
}
gpt_oss_20b_config = transformers.GptOssConfig(**gpt_oss_20b_dict)

# Copied from https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json,
# with the mxfp4 quantization_config removed since we load weights in bf16.
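# Same trunk as gpt-oss-20b; only the depth (36 vs. 24 layers) and the number
# of local experts (128 vs. 32) differ.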
gpt_oss_120b_dict = {
    "architectures": ["GptOssForCausalLM"],
    "attention_bias": True,
    "attention_dropout": 0.0,
    "eos_token_id": 200002,
    "experts_per_token": 4,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2880,
    "initial_context_length": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 2880,
    "layer_types": [
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "full_attention",
    ],
    "max_position_embeddings": 131072,
    "model_type": "gpt_oss",
    "num_attention_heads": 64,
    "num_experts_per_tok": 4,
    "num_hidden_layers": 36,
    "num_key_value_heads": 8,
    "num_local_experts": 128,
    "output_router_logits": False,
    "pad_token_id": 199999,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "factor": 32.0,
        "original_max_position_embeddings": 4096,
        "rope_type": "yarn",
        "truncate": False,
    },
    "rope_theta": 150000,
    "router_aux_loss_coef": 0.9,
    "sliding_window": 128,
    "swiglu_limit": 7.0,
    "tie_word_embeddings": False,
    "transformers_version": "4.55.0.dev0",
    "use_cache": True,
    "vocab_size": 201088,
}
gpt_oss_120b_config = transformers.GptOssConfig(**gpt_oss_120b_dict)


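# Only the thinker's text sub-config is overridden below; the remaining
# sub-configs (visual/audio/code2wav) fall back to Qwen3OmniMoeConfig defaults.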
qwen3_omni_30b_a3b_config = transformers.Qwen3OmniMoeConfig(
    # TODO(hengtaoguo): Text-only Omni config for now; fill in the visual/audio/code2wav sub-configs.
    architectures=["Qwen3OmniMoeForConditionalGeneration"],
    thinker_config={
        "text_config": {
            "num_hidden_layers": 48,
            "num_experts": 128,
        }
    },
)

# Maps MaxText model names to their Hugging Face model configs.
HF_MODEL_CONFIGS = {
    "gemma2-2b": gemma2_2b_config,
    "gemma2-9b": gemma2_9b_config,
    "gemma2-27b": gemma2_27b_config,
    "gemma3-4b": gemma3_4b_config,
    "gemma3-12b": gemma3_12b_config,
    "gemma3-27b": gemma3_27b_config,
    "qwen3-0.6b": qwen3_0_6b_config,
    "qwen3-4b": qwen3_4b_config,
    "qwen3-4b-thinking-2507": qwen3_4b_config,
    "qwen3-8b": qwen3_8b_config,
    "qwen3-14b": qwen3_14b_config,
    "qwen3-32b": qwen3_32b_config,
    "llama3.1-8b": llama31_8b_config,
    "llama3.1-8b-Instruct": llama31_8b_config,
    "llama3.1-70b": llama31_70b_config,
    "llama3.1-405b": llama31_405b_config,
    "qwen3-30b-a3b": qwen3_30b_a3b_thinking_2507_config,
    "qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
    "qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
    "deepseek3-671b": deepseek3_671b_config,
    "gpt-oss-20b": gpt_oss_20b_config,
    "gpt-oss-120b": gpt_oss_120b_config,
    "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
}
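
# Minimal usage sketch: `from_config` builds a randomly initialized model,
# dispatching on the config class (e.g. Qwen3Config -> Qwen3ForCausalLM).
# Instantiating the larger configs requires substantial host memory.
if __name__ == "__main__":
  model = transformers.AutoModelForCausalLM.from_config(HF_MODEL_CONFIGS["qwen3-0.6b"])
  print(f"{model.num_parameters():,} parameters")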
